﻿// Copyright (c) Pomelo Foundation. All rights reserved.
// Licensed under the MIT. See LICENSE in the project root for license information.

using System.IO;
using System;

using Pomelo.Data.Common;
using System.Runtime.InteropServices;
using System.Diagnostics;
using System.Security;

namespace Pomelo.Data.MySql.Authentication
{
    /// <summary>
    /// 
    /// </summary>
#if NETFW
    [SuppressUnmanagedCodeSecurityAttribute()]
#endif
    internal class MySqlWindowsAuthenticationPlugin : MySqlAuthenticationPlugin
    {
        SECURITY_HANDLE outboundCredentials = new SECURITY_HANDLE(0);
        SECURITY_HANDLE clientContext = new SECURITY_HANDLE(0);
        SECURITY_INTEGER lifetime = new SECURITY_INTEGER(0);
        bool continueProcessing;
        string targetName = null;

        protected override void CheckConstraints()
        {
            string platform = String.Empty;
#if NETFW
            int p = (int)Environment.OSVersion.Platform;
            if ((p == 4) || (p == 128))
                platform = "Unix";
            else if (Environment.OSVersion.Platform == PlatformID.MacOSX)
                platform = "Mac OS/X";

            if (!String.IsNullOrEmpty(platform))
                throw new MySqlException(String.Format(Resources.WinAuthNotSupportOnPlatform, platform));
#endif
            base.CheckConstraints();
        }

        public override string GetUsername()
        {
            string username = base.GetUsername();
            if (String.IsNullOrEmpty(username))
                return "auth_windows";
            return username;
        }

        public override string PluginName
        {
            get { return "authentication_windows_client"; }
        }

        protected override byte[] MoreData(byte[] moreData)
        {
            if (moreData == null)
                AcquireCredentials();

            byte[] clientBlob = null;

            if (continueProcessing)
                InitializeClient(out clientBlob, moreData, out continueProcessing);

            if (!continueProcessing || clientBlob == null || clientBlob.Length == 0)
            {
                FreeCredentialsHandle(ref outboundCredentials);
                DeleteSecurityContext(ref clientContext);
                return null;
            }
            return clientBlob;
        }

        void InitializeClient(out byte[] clientBlob, byte[] serverBlob, out bool continueProcessing)
        {
            clientBlob = null;
            continueProcessing = true;
            SecBufferDesc clientBufferDesc = new SecBufferDesc(MAX_TOKEN_SIZE);
            SECURITY_INTEGER initLifetime = new SECURITY_INTEGER(0);
            int ss = -1;
            try
            {
                uint ContextAttributes = 0;

                if (serverBlob == null)
                {
                    ss = InitializeSecurityContext(
                        ref outboundCredentials,
                        IntPtr.Zero,
                        targetName,
                        STANDARD_CONTEXT_ATTRIBUTES,
                        0,
                        SECURITY_NETWORK_DREP,
                        IntPtr.Zero, /* always zero first time around */
                        0,
                        out clientContext,
                        out clientBufferDesc,
                        out ContextAttributes,
                        out initLifetime);

                }
                else
                {
                    SecBufferDesc serverBufferDesc = new SecBufferDesc(serverBlob);

                    try
                    {
                        ss = InitializeSecurityContext(ref outboundCredentials,
                            ref clientContext,
                            targetName,
                            STANDARD_CONTEXT_ATTRIBUTES,
                            0,
                            SECURITY_NETWORK_DREP,
                            ref serverBufferDesc,
                            0,
                            out clientContext,
                            out clientBufferDesc,
                            out ContextAttributes,
                            out initLifetime);
                    }
                    finally
                    {
                        serverBufferDesc.Dispose();
                    }
                }


                if ((SEC_I_COMPLETE_NEEDED == ss)
                    || (SEC_I_COMPLETE_AND_CONTINUE == ss))
                {
                    CompleteAuthToken(ref clientContext, ref clientBufferDesc);
                }

                if (ss != SEC_E_OK &&
                    ss != SEC_I_CONTINUE_NEEDED &&
                    ss != SEC_I_COMPLETE_NEEDED &&
                    ss != SEC_I_COMPLETE_AND_CONTINUE)
                {
                    throw new MySqlException(
                        "InitializeSecurityContext() failed  with errorcode " + ss);
                }

                clientBlob = clientBufferDesc.GetSecBufferByteArray();
            }
            finally
            {
                clientBufferDesc.Dispose();
            }
            continueProcessing = (ss != SEC_E_OK && ss != SEC_I_COMPLETE_NEEDED);
        }

        /// <summary>
        /// Currently this method is unused
        /// </summary>
        /// <returns></returns>
        private string GetTargetName()
        {
            return null;
            /*if (AuthenticationData == null) return String.Empty;

            int index = -1;
            for (int i = 0; i < AuthenticationData.Length; i++)
            {
              if (AuthenticationData[i] != 0) continue;
              index = i;
              break;
            }
            if (index == -1)
      #if RT
              targetName = System.Text.Encoding.UTF8.GetString(AuthenticationData, 0, AuthenticationData.Length);
      #else
              targetName = System.Text.Encoding.UTF8.GetString(AuthenticationData);
      #endif
            else
                targetName = System.Text.Encoding.UTF8.GetString(AuthenticationData, 0, index);
            return targetName;*/
        }

        private void AcquireCredentials()
        {

            continueProcessing = true;

            int ss = AcquireCredentialsHandle(null, "Negotiate", SECPKG_CRED_OUTBOUND,
                    IntPtr.Zero, IntPtr.Zero, 0, IntPtr.Zero, ref outboundCredentials,
                    ref lifetime);
            if (ss != SEC_E_OK)
                throw new MySqlException("AcquireCredentialsHandle failed with errorcode" + ss);
        }

        #region SSPI Constants and Imports

        const int SEC_E_OK = 0;
        const int SEC_I_CONTINUE_NEEDED = 0x90312;
        const int SEC_I_COMPLETE_NEEDED = 0x1013;
        const int SEC_I_COMPLETE_AND_CONTINUE = 0x1014;

        const int SECPKG_CRED_OUTBOUND = 2;
        const int SECURITY_NETWORK_DREP = 0;
        const int SECURITY_NATIVE_DREP = 0x10;
        const int SECPKG_CRED_INBOUND = 1;
        const int MAX_TOKEN_SIZE = 12288;
        const int SECPKG_ATTR_SIZES = 0;
        const int STANDARD_CONTEXT_ATTRIBUTES = 0;

        [DllImport("secur32", CharSet = CharSet.Unicode)]
        static extern int AcquireCredentialsHandle(
            string pszPrincipal,
            string pszPackage,
            int fCredentialUse,
            IntPtr PAuthenticationID,
            IntPtr pAuthData,
            int pGetKeyFn,
            IntPtr pvGetKeyArgument,
            ref SECURITY_HANDLE phCredential,
            ref SECURITY_INTEGER ptsExpiry);

        [DllImport("secur32", CharSet = CharSet.Unicode, SetLastError = true)]
        static extern int InitializeSecurityContext(
            ref SECURITY_HANDLE phCredential,
            IntPtr phContext,
            string pszTargetName,
            int fContextReq,
            int Reserved1,
            int TargetDataRep,
            IntPtr pInput,
            int Reserved2,
            out SECURITY_HANDLE phNewContext,
            out SecBufferDesc pOutput,
            out uint pfContextAttr,
            out SECURITY_INTEGER ptsExpiry);

        [DllImport("secur32", CharSet = CharSet.Unicode, SetLastError = true)]
        static extern int InitializeSecurityContext(
            ref SECURITY_HANDLE phCredential,
            ref SECURITY_HANDLE phContext,
            string pszTargetName,
            int fContextReq,
            int Reserved1,
            int TargetDataRep,
            ref SecBufferDesc SecBufferDesc,
            int Reserved2,
            out SECURITY_HANDLE phNewContext,
            out SecBufferDesc pOutput,
            out uint pfContextAttr,
            out SECURITY_INTEGER ptsExpiry);

        [DllImport("secur32", CharSet = CharSet.Unicode, SetLastError = true)]
        static extern int CompleteAuthToken(
            ref SECURITY_HANDLE phContext,
            ref SecBufferDesc pToken);

        [DllImport("secur32.Dll", CharSet = CharSet.Unicode, SetLastError = false)]
        public static extern int QueryContextAttributes(
            ref SECURITY_HANDLE phContext,
            uint ulAttribute,
            out SecPkgContext_Sizes pContextAttributes);

        [DllImport("secur32.Dll", CharSet = CharSet.Unicode, SetLastError = false)]
        public static extern int FreeCredentialsHandle(ref SECURITY_HANDLE pCred);

        [DllImport("secur32.Dll", CharSet = CharSet.Unicode, SetLastError = false)]
        public static extern int DeleteSecurityContext(ref SECURITY_HANDLE pCred);

        #endregion
    }

    [StructLayout(LayoutKind.Sequential)]
    struct SecBufferDesc : IDisposable
    {

        public int ulVersion;
        public int cBuffers;
        public IntPtr pBuffers; //Point to SecBuffer

        public SecBufferDesc(int bufferSize)
        {
            ulVersion = (int)SecBufferType.SECBUFFER_VERSION;
            cBuffers = 1;
            SecBuffer secBuffer = new SecBuffer(bufferSize);
            pBuffers = Marshal.AllocHGlobal(Marshal.SizeOf(secBuffer));
            Marshal.StructureToPtr(secBuffer, pBuffers, false);
        }

        public SecBufferDesc(byte[] secBufferBytes)
        {
            ulVersion = (int)SecBufferType.SECBUFFER_VERSION;
            cBuffers = 1;
            SecBuffer ThisSecBuffer = new SecBuffer(secBufferBytes);
            pBuffers = Marshal.AllocHGlobal(Marshal.SizeOf(ThisSecBuffer));
            Marshal.StructureToPtr(ThisSecBuffer, pBuffers, false);
        }

        public void Dispose()
        {
            if (pBuffers != IntPtr.Zero)
            {
                Debug.Assert(cBuffers == 1);
#if NET452 || DNX452 || NETSTANDARD1_3
        SecBuffer ThisSecBuffer =
            (SecBuffer)Marshal.PtrToStructure<SecBuffer>(pBuffers);
#else
                SecBuffer ThisSecBuffer = (SecBuffer)Marshal.PtrToStructure(pBuffers, typeof(SecBuffer));
#endif
                ThisSecBuffer.Dispose();
                Marshal.FreeHGlobal(pBuffers);
                pBuffers = IntPtr.Zero;
            }
        }

        public byte[] GetSecBufferByteArray()
        {
            byte[] Buffer = null;

            if (pBuffers == IntPtr.Zero)
            {
                throw new InvalidOperationException("Object has already been disposed!!!");
            }
            Debug.Assert(cBuffers == 1);
#if NET452 || DNX452 || NETSTANDARD1_3
      SecBuffer secBuffer = (SecBuffer)Marshal.PtrToStructure<SecBuffer>(pBuffers);
#else
            SecBuffer secBuffer = (SecBuffer)Marshal.PtrToStructure(pBuffers, typeof(SecBuffer));
#endif
            if (secBuffer.cbBuffer > 0)
            {
                Buffer = new byte[secBuffer.cbBuffer];
                Marshal.Copy(secBuffer.pvBuffer, Buffer, 0, secBuffer.cbBuffer);
            }
            return (Buffer);
        }

    }

    public enum SecBufferType
    {
        SECBUFFER_VERSION = 0,
        SECBUFFER_EMPTY = 0,
        SECBUFFER_DATA = 1,
        SECBUFFER_TOKEN = 2
    }

    [StructLayout(LayoutKind.Sequential)]
    public struct SecHandle //=PCtxtHandle
    {
        IntPtr dwLower; // ULONG_PTR translates to IntPtr not to uint
        IntPtr dwUpper; // this is crucial for 64-Bit Platforms
    }

    [StructLayout(LayoutKind.Sequential)]
    public struct SecBuffer : IDisposable
    {
        public int cbBuffer;
        public int BufferType;
        public IntPtr pvBuffer;


        public SecBuffer(int bufferSize)
        {
            cbBuffer = bufferSize;
            BufferType = (int)SecBufferType.SECBUFFER_TOKEN;
            pvBuffer = Marshal.AllocHGlobal(bufferSize);
        }

        public SecBuffer(byte[] secBufferBytes)
        {
            cbBuffer = secBufferBytes.Length;
            BufferType = (int)SecBufferType.SECBUFFER_TOKEN;
            pvBuffer = Marshal.AllocHGlobal(cbBuffer);
            Marshal.Copy(secBufferBytes, 0, pvBuffer, cbBuffer);
        }

        public SecBuffer(byte[] secBufferBytes, SecBufferType bufferType)
        {
            cbBuffer = secBufferBytes.Length;
            BufferType = (int)bufferType;
            pvBuffer = Marshal.AllocHGlobal(cbBuffer);
            Marshal.Copy(secBufferBytes, 0, pvBuffer, cbBuffer);
        }

        public void Dispose()
        {
            if (pvBuffer != IntPtr.Zero)
            {
                Marshal.FreeHGlobal(pvBuffer);
                pvBuffer = IntPtr.Zero;
            }
        }
    }

    [StructLayout(LayoutKind.Sequential)]
    public struct SECURITY_INTEGER
    {
        public uint LowPart;
        public int HighPart;
        public SECURITY_INTEGER(int dummy)
        {
            LowPart = 0;
            HighPart = 0;
        }
    };

    [StructLayout(LayoutKind.Sequential)]
    public struct SECURITY_HANDLE
    {
        public IntPtr LowPart;
        public IntPtr HighPart;
        public SECURITY_HANDLE(int dummy)
        {
            LowPart = HighPart = new IntPtr(0);
        }
    };

    [StructLayout(LayoutKind.Sequential)]
    public struct SecPkgContext_Sizes
    {
        public uint cbMaxToken;
        public uint cbMaxSignature;
        public uint cbBlockSize;
        public uint cbSecurityTrailer;
    };

}
